// Copyright © 2022 Cisco Systems, Inc. and its affiliates.
// All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package vulnerability

import (
	"math"
	"strings"

	log "github.com/sirupsen/logrus"
)

// Severity labels understood by SeverityModel.GetVulnerabilitySeverityFromString.
// Input labels are upper-cased before being compared against these values.
const (
	// DEFCON1 is treated as a synonym for CRITICAL.
	DEFCON1  = "DEFCON1"
	CRITICAL = "CRITICAL"
	HIGH     = "HIGH"
	MEDIUM   = "MEDIUM"
	LOW      = "LOW"

	// The labels below are all mapped to the Negligible level.
	NEGLIGIBLE = "NEGLIGIBLE"
	UNKNOWN    = "UNKNOWN"
	NONE       = "NONE"
	INFO       = "INFO"
)

// SeverityModel holds one caller-chosen value of type T per normalized
// severity level. GetVulnerabilitySeverityFromString picks the field that
// matches a severity label.
//
// NOTE: keep the field order stable; callers may build this struct with a
// positional composite literal.
type SeverityModel[T any] struct {
	Critical   T // selected for DEFCON1 or CRITICAL
	High       T // selected for HIGH
	Medium     T // selected for MEDIUM
	Low        T // selected for LOW
	Negligible T // selected for NEGLIGIBLE, UNKNOWN, NONE, INFO, or any unrecognized label
}

// GetVulnerabilitySeverityFromString returns the field of m that corresponds
// to the given severity label. The label is upper-cased before matching.
//
// DEFCON1 and CRITICAL select m.Critical. HIGH, MEDIUM and LOW select the
// fields of the same name. NEGLIGIBLE, UNKNOWN, NONE and INFO select
// m.Negligible. Any other label is logged as a warning and also yields
// m.Negligible.
func (m SeverityModel[T]) GetVulnerabilitySeverityFromString(severity string) T {
	normalized := strings.ToUpper(severity)

	switch normalized {
	case DEFCON1, CRITICAL:
		return m.Critical
	case HIGH:
		return m.High
	case MEDIUM:
		return m.Medium
	case LOW:
		return m.Low
	case NEGLIGIBLE, UNKNOWN, NONE, INFO:
		// Recognized low-signal labels: map to Negligible without a warning.
	default:
		log.Warnf("Unknown severity: %v", severity)
	}

	return m.Negligible
}

// GetCVSSV3VersionFromVector returns the version component of a CVSS v3
// vector string. It takes the text before the first "/" and strips a leading
// "CVSS:" from it.
//
// For example, "CVSS:3.0/AV:N/AC:L/PR:N/UI:N/S:U/C:L/I:L/A:H" yields "3.0".
// Input without the "CVSS:" prefix yields its first "/"-separated segment
// unchanged.
func GetCVSSV3VersionFromVector(cvssV3Vector string) string {
	head, _, _ := strings.Cut(cvssV3Vector, "/")
	return strings.TrimPrefix(head, "CVSS:")
}

// v3ScoreValues holds the CVSS v3.x base-metric weights. It is keyed first by
// metric abbreviation and then by metric value.
//
// nolint:mnd
var v3ScoreValues = map[string]map[string]float64{
	"AV": {
		"N": 0.85,
		"A": 0.62,
		"L": 0.55,
		"P": 0.2,
	},
	"AC": {
		"L": 0.77,
		"H": 0.44,
	},
	"PR": {
		"N": 0.85,
		"L": 0.62,
		"H": 0.27,
	},
	// Privileges Required weights used instead of "PR" when Scope is Changed (S:C).
	"PR-MS": {
		"N": 0.85,
		"L": 0.68,
		"H": 0.5,
	},
	"UI": {
		"N": 0.85,
		"R": 0.62,
	},
	"C": {
		"H": 0.56,
		"L": 0.22,
		"N": 0.0,
	},
	"I": {
		"H": 0.56,
		"L": 0.22,
		"N": 0.0,
	},
	"A": {
		"H": 0.56,
		"L": 0.22,
		"N": 0.0,
	},
}

// ExploitScoreAndImpactScoreFromV3Vector computes the CVSS v3.x
// Exploitability and Impact sub-scores from a base vector string such as
// "CVSS:3.1/AV:N/AC:L/PR:N/UI:N/S:U/C:H/I:H/A:H". Both results are rounded to
// one decimal place with RoundToOneDecimalPlace.
//
// A metric that is missing or has an unrecognized value contributes a weight
// of 0. Scope counts as Changed only when S is exactly "C"; any other value,
// including a missing S, uses the Scope Unchanged formulas. Both sub-scores are
// always computed under the same scope. The Scope Changed impact formula can
// be negative (for example when C, I and A are all N); such a value is
// returned as-is.
//
// https://www.first.org/cvss/v3.1/specification-document#7-1-Base-Metrics-Equations
// https://www.first.org/cvss/v3.0/specification-document#CVSS-v3-0-Equations
// nolint:mnd
func ExploitScoreAndImpactScoreFromV3Vector(cvssV3Vector string) (float64, float64) {
	metrics := map[string]string{}
	for _, part := range strings.Split(cvssV3Vector, "/") {
		name, value, _ := strings.Cut(part, ":")
		metrics[name] = value
	}

	// Decide scope once. Previously the impact formula treated any S other than
	// "U" as Changed, while the PR weight treated only "C" as Changed. With a
	// missing S, the two sub-scores therefore disagreed.
	scopeChanged := metrics["S"] == "C"

	// Impact Sub-Score (ISS) from the C/I/A weights.
	iss := 1.0 - ((1.0 - v3ScoreValues["C"][metrics["C"]]) *
		(1.0 - v3ScoreValues["I"][metrics["I"]]) *
		(1.0 - v3ScoreValues["A"][metrics["A"]]))

	var impact float64
	privMetric := "PR"
	if scopeChanged {
		impact = 7.52*(iss-0.029) - 3.25*math.Pow(iss-0.02, 15)
		// Scope Changed uses the larger PR weights for L and H.
		privMetric = "PR-MS"
	} else {
		impact = 6.42 * iss
	}

	exploitability := 8.22 *
		v3ScoreValues["AV"][metrics["AV"]] *
		v3ScoreValues["AC"][metrics["AC"]] *
		v3ScoreValues[privMetric][metrics["PR"]] *
		v3ScoreValues["UI"][metrics["UI"]]

	return RoundToOneDecimalPlace(exploitability), RoundToOneDecimalPlace(impact)
}

// RoundToOneDecimalPlace rounds score to the nearest tenth using math.Round
// semantics on score*10 (halves go away from zero). This is ordinary
// rounding, not the CVSS "Roundup" function used for base scores.
//
// nolint:mnd
func RoundToOneDecimalPlace(score float64) float64 {
	return math.Round(score*10.0) / 10.0
}

// v2ScoreValues holds the CVSS v2 base-metric weights. It is keyed first by
// metric abbreviation (AV, AC, Au, C, I, A) and then by metric value.
//
// nolint:mnd
var v2ScoreValues = map[string]map[string]float64{
	"AV": {
		"L": 0.395,
		"A": 0.646,
		"N": 1.0,
	},
	"AC": {
		"L": 0.71,
		"M": 0.61,
		"H": 0.35,
	},
	"Au": {
		"M": 0.45,
		"S": 0.56,
		"N": 0.704,
	},
	"C": {
		"N": 0.0,
		"P": 0.275,
		"C": 0.660,
	},
	"I": {
		"N": 0.0,
		"P": 0.275,
		"C": 0.660,
	},
	"A": {
		"N": 0.0,
		"P": 0.275,
		"C": 0.660,
	},
}

// ExploitScoreAndImpactScoreFromV2Vector computes the CVSS v2 Exploitability
// and Impact sub-scores from a base vector string such as
// "AV:N/AC:L/Au:N/C:P/I:P/A:P". Both results are rounded to one decimal place
// with RoundToOneDecimalPlace.
//
// Surrounding whitespace and an enclosing pair of parentheses, as in
// "(AV:N/AC:L/Au:N/C:P/I:P/A:P)", are ignored. A metric that is missing or
// has an unrecognized value contributes a weight of 0.
//
// https://www.first.org/cvss/v2/guide
// nolint:mnd
func ExploitScoreAndImpactScoreFromV2Vector(cvssV2Vector string) (float64, float64) {
	// Without this trimming, "(AV" and "P)" would not match any key, so the
	// first and last metrics would silently score 0.
	vector := strings.TrimSpace(cvssV2Vector)
	vector = strings.TrimSuffix(strings.TrimPrefix(vector, "("), ")")

	metrics := map[string]string{}
	for _, part := range strings.Split(vector, "/") {
		name, value, _ := strings.Cut(part, ":")
		metrics[name] = value
	}

	// Impact = 10.41 * (1 - (1-C)*(1-I)*(1-A))
	iss := 1.0 - ((1.0 - v2ScoreValues["C"][metrics["C"]]) *
		(1.0 - v2ScoreValues["I"][metrics["I"]]) *
		(1.0 - v2ScoreValues["A"][metrics["A"]]))
	impact := 10.41 * iss

	// Exploitability = 20 * AV * AC * Au
	exploit := 20.0 *
		v2ScoreValues["AV"][metrics["AV"]] *
		v2ScoreValues["AC"][metrics["AC"]] *
		v2ScoreValues["Au"][metrics["Au"]]

	return RoundToOneDecimalPlace(exploit), RoundToOneDecimalPlace(impact)
}
